# -*- coding: utf-8 -*-
from tkinter import *
from tool import common
from train import trainFn
from train import trainVgg
from train import trainResnet
from train import predictFn
from train import loadmodel
from train import pruning
import threading


def center_window(root, width, height):
    """Resize *root* to width x height and place it in the middle of the screen.

    The left/top offsets are half the leftover screen space, truncated toward
    zero, and are applied with a single ``root.geometry('WxH+X+Y')`` call.
    If the window is larger than the screen the offsets come out negative.
    """
    left = int((root.winfo_screenwidth() - width) / 2)
    top = int((root.winfo_screenheight() - height) / 2)
    root.geometry('{}x{}+{}+{}'.format(int(width), int(height), left, top))


def frozenGraph_tic_type():
    """Run trainFn.frozenGraph_tic_type using the meta filename from the TIC-type Entry.

    The filename is read from the ``metaName_tic_type`` StringVar at click
    time, so edits made in the GUI take effect.
    """
    meta_file = metaName_tic_type.get()
    trainFn.frozenGraph_tic_type(metaName=meta_file)


def frozenGraph_tic_length():
    """Run trainFn.frozenGraph_tic_length using the meta filename from the TIC-length Entry.

    The filename is read from the ``metaName_tic_length`` StringVar at click
    time, so edits made in the GUI take effect.
    """
    meta_file = metaName_tic_length.get()
    trainFn.frozenGraph_tic_length(metaName=meta_file)



def frozenGraph_axle_type():
    """Run trainFn.frozenGraph_axle_type using the meta filename from the axle-type Entry.

    Mirrors frozenGraph_tic_type / frozenGraph_tic_length so that the value
    typed into the axle-type frame's Entry (bound to ``metaName_axle_type``)
    is actually used when its save button is clicked.
    """
    # NOTE(review): assumes trainFn.frozenGraph_axle_type accepts a `metaName`
    # keyword like its tic_type / tic_length counterparts -- confirm in train/trainFn.
    trainFn.frozenGraph_axle_type(metaName=metaName_axle_type.get())


# Main window: kept above other windows, fixed at 600x400 and centred on screen.
root = Tk()
root.wm_attributes('-topmost', 1)
root.title("训练系统")
center_window(root, 600, 400)
root.resizable(width=False, height=False)

# Checkpoint .meta filenames handed to the "save parameters" buttons. Each is
# bound to an Entry widget below so it can be edited before saving.
metaName_tic_type = StringVar(value='model.ckpt-10000.meta')
metaName_tic_length = StringVar(value='model.ckpt-15000.meta')
metaName_axle_type = StringVar(value='model.ckpt-15000.meta')

padx = 5
pady = 5

# Bottom toolbar (clear log / start TensorBoard / open Chrome). Packed first
# with side=BOTTOM so it claims the bottom strip before the LEFT-packed columns.
frame_top = Frame(root, padx=padx, pady=pady)
Button(frame_top, text="清空log", command=common.clearFile_log).pack(side=LEFT)
Button(frame_top, text="启动tensorboard", command=common.start_tensorboard_thread).pack(side=LEFT)
Button(frame_top, text="打开chrome", command=common.open_chrome_thread).pack(side=LEFT)
frame_top.pack(side=BOTTOM)

# Column 1 -- TIC type: data conversion, training (two back ends), prediction,
# saving parameters, pruning, and the meta-filename Entry used when saving.
frame1 = LabelFrame(root, text="TIC 类型", padx=padx, pady=pady)
Button(frame1, text="训练数据转换(分割)", command=common.dataProcess_tic_type.convertData_train_split).pack(anchor='w')
Button(frame1, text="训练数据转换", command=common.dataProcess_tic_type.convertData_train).pack(anchor='w')
Button(frame1, text="测试数据转换", command=common.dataProcess_tic_type.convertData_test).pack(anchor='w')
Button(frame1, text="训练", command=trainFn.tic_train_type_thread).pack(anchor='w')
Button(frame1, text="keras训练", command=trainVgg.tic_train_type_thread).pack(anchor='w')
Button(frame1, text="预测", command=predictFn.predict_tic_type).pack(anchor='w')
Button(frame1, text="保存参数", command=frozenGraph_tic_type).pack(anchor='w')
Button(frame1, text="pruning", command=pruning.pruning_thread).pack(anchor='w')
Entry(frame1, textvariable=metaName_tic_type).pack(anchor='w')
frame1.pack(side=LEFT, anchor='n')

# Column 2 -- TIC length: data conversion, training, saving parameters.
frame2 = LabelFrame(root, text="TIC 长度", padx=padx, pady=pady)
Button(frame2, text="训练数据转换(分割)", command=common.dataProcess_tic_length.convertData_train_split).pack(anchor='w')
Button(frame2, text="训练数据转换", command=common.dataProcess_tic_length.convertData_train).pack(anchor='w')
Button(frame2, text="测试数据转换", command=common.dataProcess_tic_length.convertData_test).pack(anchor='w')
Button(frame2, text="训练", command=trainFn.tic_train_length_thread).pack(anchor='w')
Button(frame2, text="保存参数", command=frozenGraph_tic_length).pack(anchor='w')
Entry(frame2, textvariable=metaName_tic_length).pack(anchor='w')
frame2.pack(side=LEFT, anchor='n')

# Column 3 -- LWH length: data conversion and training only.
frame3 = LabelFrame(root, text="LWH 长度", padx=padx, pady=pady)
Button(frame3, text="训练数据转换", command=common.dataProcess_lwh_length.convertData_train).pack(anchor='w')
Button(frame3, text="测试数据转换", command=common.dataProcess_lwh_length.convertData_test).pack(anchor='w')
Button(frame3, text="训练", command=trainFn.lwh_train_length_thread).pack(anchor='w')
frame3.pack(side=LEFT, anchor='n')

# Column 4 -- axle type. The save button goes through frozenGraph_axle_type so
# the Entry's meta filename is honoured, consistent with the TIC columns.
frame4 = LabelFrame(root, text="轴型", padx=padx, pady=pady)
Button(frame4, text="训练数据转换", command=common.dataProcess_axle_type.convertData_train).pack(anchor='w')
Button(frame4, text="测试数据转换", command=common.dataProcess_axle_type.convertData_test).pack(anchor='w')
Button(frame4, text="训练", command=trainFn.axle_train_type_thread).pack(anchor='w')
Button(frame4, text="保存参数", command=frozenGraph_axle_type).pack(anchor='w')
Entry(frame4, textvariable=metaName_axle_type).pack(anchor='w')
Button(frame4, text="keras训练", command=trainVgg.axle_train_type_thread).pack(anchor='w')
Button(frame4, text="pruning", command=pruning.pruning_axle_thread).pack(anchor='w')
Button(frame4, text="resnet训练", command=trainResnet.axle_train_type_thread).pack(anchor='w')
frame4.pack(side=LEFT, anchor='n')

# Hand control to Tk's event loop; returns when the window is closed.
root.mainloop()
